{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# default_exp problem_types.pretrain\n",
    "# development helpers: reload edited modules automatically before each cell runs\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import os\n",
    "# CUDA_VISIBLE_DEVICES=-1 is meant to hide every GPU from this process;\n",
    "# presumably it has to be set before TensorFlow is imported - TODO confirm\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"-1\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# test setup\n",
    "# import tensorflow as tf\n",
    "# import numpy as np\n",
    "# from m3tl.test_base import TestBase\n",
    "# from m3tl.input_fn import train_eval_input_fn\n",
    "# from m3tl.test_base import test_top_layer\n",
    "# test_base = TestBase()\n",
    "# params = test_base.params\n",
    "\n",
    "# params.assign_problem('weibo_fake_ner&weibo_fake_cls|weibo_fake_multi_cls|weibo_masklm|weibo_pretrain',base_dir=test_base.tmpckptdir)\n",
    "\n",
    "# hidden_dim = params.bert_config.hidden_size\n",
    "\n",
    "# train_dataset = train_eval_input_fn(params=params)\n",
    "# one_batch = next(train_dataset.as_numpy_iterator())\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Pretrain(pretrain)\n",
    "\n",
    "This module includes the necessary parts to register the pretrain problem type."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports and utils\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "from typing import Dict, List, Tuple\n",
    "\n",
    "from loguru import logger\n",
    "import tensorflow as tf\n",
    "import transformers\n",
    "from m3tl.base_params import BaseParams\n",
    "from m3tl.problem_types.utils import empty_tensor_handling_loss\n",
    "from m3tl.special_tokens import PREDICT\n",
    "from m3tl.utils import (LabelEncoder, gather_indexes, get_phase,\n",
    "                        variable_summaries)\n",
    "from transformers import TFSharedEmbeddings\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Top Layer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "\n",
    "class PreTrain(tf.keras.Model):\n",
    "    \"\"\"Top layer of the `pretrain` problem type: NSP head plus masked LM head.\n",
    "\n",
    "    The masked LM projection is tied to the body model's word embedding\n",
    "    weight through `TFSharedEmbeddings` when sharing is enabled and\n",
    "    `bert_config.hidden_size` equals the embedding size; otherwise a separate\n",
    "    `Dense(vocab_size)` layer is used.\n",
    "\n",
    "    Args:\n",
    "        params: Global params. `bert_config` is always read,\n",
    "            `share_embedding` is read unless the `share_embedding` argument\n",
    "            is False, and `detail_log` is read on every call.\n",
    "        problem_name: Used as the Keras model name and as the tag passed to\n",
    "            `variable_summaries`.\n",
    "        input_embeddings: Object exposing a `word_embeddings` weight whose\n",
    "            shape is `(vocab_size, embedding_size)`. Required unless\n",
    "            `share_embedding` is False.\n",
    "        share_embedding: Pass False to skip weight tying; `input_embeddings`\n",
    "            is then never touched.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If `share_embedding` is not False and `input_embeddings`\n",
    "            is None.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, params: BaseParams, problem_name: str, input_embeddings: tf.Tensor=None, share_embedding=True):\n",
    "        super(PreTrain, self).__init__(name=problem_name)\n",
    "        self.problem_name = problem_name\n",
    "        self.params = params\n",
    "        self.nsp = transformers.models.bert.modeling_tf_bert.TFBertNSPHead(\n",
    "            self.params.bert_config)\n",
    "\n",
    "        word_embedding_weight = None\n",
    "        if share_embedding is False:\n",
    "            self.vocab_size = self.params.bert_config.vocab_size\n",
    "            self.share_embedding = False\n",
    "        else:\n",
    "            if input_embeddings is None:\n",
    "                # the defaults (input_embeddings=None, share_embedding=True)\n",
    "                # used to die with an AttributeError on None.word_embeddings;\n",
    "                # fail early with an actionable message instead\n",
    "                raise ValueError(\n",
    "                    'input_embeddings is required unless share_embedding=False, '\n",
    "                    'got None for problem {}'.format(problem_name))\n",
    "            word_embedding_weight = input_embeddings.word_embeddings\n",
    "            self.vocab_size = word_embedding_weight.shape[0]\n",
    "            embedding_size = word_embedding_weight.shape[-1]\n",
    "            # the projection can only reuse the embedding weight when the\n",
    "            # hidden states have the same width as the embedding\n",
    "            share_valid = (\n",
    "                self.params.bert_config.hidden_size == embedding_size)\n",
    "            if not share_valid and self.params.share_embedding:\n",
    "                logger.warning(\n",
    "                    'Share embedding is enabled but hidden_size != embedding_size')\n",
    "            # logical `and` instead of bitwise `&`: always yields a plain bool\n",
    "            self.share_embedding = bool(\n",
    "                self.params.share_embedding) and share_valid\n",
    "\n",
    "        if self.share_embedding:\n",
    "            self.share_embedding_layer = TFSharedEmbeddings(\n",
    "                vocab_size=word_embedding_weight.shape[0], hidden_size=word_embedding_weight.shape[1])\n",
    "            # build the layer, then point its weight at the shared embedding\n",
    "            self.share_embedding_layer.build([1])\n",
    "            self.share_embedding_layer.weight = word_embedding_weight\n",
    "        else:\n",
    "            self.share_embedding_layer = tf.keras.layers.Dense(self.vocab_size)\n",
    "\n",
    "    def call(self,\n",
    "             inputs: Tuple[Dict[str, Dict[str, tf.Tensor]], Dict[str, Dict[str, tf.Tensor]]]) -> Tuple[tf.Tensor, tf.Tensor]:\n",
    "        \"\"\"Compute NSP/MLM predictions and, outside PREDICT, register the loss.\n",
    "\n",
    "        Args:\n",
    "            inputs: `(features, hidden_features)`. `hidden_features` is read\n",
    "                at 'pooled' and 'seq'. Outside PREDICT, `features` is read at\n",
    "                'masked_lm_positions', 'masked_lm_ids' and\n",
    "                'next_sentence_label_ids'.\n",
    "\n",
    "        Returns:\n",
    "            `(tf.sigmoid(nsp_logits), tf.nn.softmax(mlm_logits))`.\n",
    "        \"\"\"\n",
    "        mode = get_phase()\n",
    "        features, hidden_features = inputs\n",
    "        # labels and masked positions are only consumed outside PREDICT\n",
    "        with_labels = mode != PREDICT\n",
    "\n",
    "        # compute logits\n",
    "        nsp_logits = self.nsp(hidden_features['pooled'])\n",
    "\n",
    "        # masking is done inside the model\n",
    "        seq_hidden_feature = hidden_features['seq']\n",
    "        if with_labels:\n",
    "            positions = features['masked_lm_positions']\n",
    "            hidden_size = seq_hidden_feature.shape.as_list()[-1]\n",
    "\n",
    "            # gather_indexes will flatten the seq hidden_states, we need to reshape\n",
    "            # back to 3d tensor\n",
    "            input_tensor = gather_indexes(seq_hidden_feature, positions)\n",
    "            shape_list = tf.concat(\n",
    "                [tf.shape(positions), [hidden_size]], axis=0)\n",
    "            input_tensor = tf.reshape(input_tensor, shape=shape_list)\n",
    "            # set_shape to determine rank\n",
    "            input_tensor.set_shape([None, None, hidden_size])\n",
    "        else:\n",
    "            input_tensor = seq_hidden_feature\n",
    "\n",
    "        if self.share_embedding:\n",
    "            mlm_logits = self.share_embedding_layer(\n",
    "                input_tensor, mode='linear')\n",
    "        else:\n",
    "            mlm_logits = self.share_embedding_layer(input_tensor)\n",
    "\n",
    "        if self.params.detail_log:\n",
    "            for weight_variable in self.weights:\n",
    "                variable_summaries(weight_variable, self.problem_name)\n",
    "\n",
    "        if with_labels:\n",
    "            nsp_labels = features['next_sentence_label_ids']\n",
    "            mlm_labels = features['masked_lm_ids']\n",
    "            mlm_labels.set_shape([None, None])\n",
    "            # compute loss\n",
    "            nsp_loss = empty_tensor_handling_loss(\n",
    "                nsp_labels, nsp_logits,\n",
    "                tf.keras.losses.sparse_categorical_crossentropy)\n",
    "            mlm_loss_layer = transformers.modeling_tf_utils.TFMaskedLanguageModelingLoss()\n",
    "\n",
    "            # add a useless from_logits argument to match the function signature of keras losses.\n",
    "            def loss_fn_wrapper(labels, logits, from_logits=True):\n",
    "                return mlm_loss_layer.compute_loss(labels, logits)\n",
    "            mlm_loss = empty_tensor_handling_loss(\n",
    "                mlm_labels,\n",
    "                mlm_logits,\n",
    "                loss_fn_wrapper\n",
    "            )\n",
    "            loss = nsp_loss + mlm_loss\n",
    "            self.add_loss(loss)\n",
    "\n",
    "        # NOTE(review): the NSP loss above is a categorical cross-entropy, so\n",
    "        # softmax may be the matching activation for nsp_logits - confirm\n",
    "        # before changing, the current outputs are kept as they were\n",
    "        return (tf.sigmoid(nsp_logits), tf.nn.softmax(mlm_logits))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# test_top_layer(PreTrain, problem='weibo_fake_cls', params=params, sample_features=one_batch, hidden_dim=hidden_dim, share_embedding=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get or make label encoder function\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "def pretrain_get_or_make_label_encoder_fn(\n",
    "        params: BaseParams,\n",
    "        problem: str,\n",
    "        mode: str,\n",
    "        label_list: List[str],\n",
    "        *args,\n",
    "        **kwargs) -> LabelEncoder:\n",
    "    \"\"\"Record `num_classes=1` for `problem` on `params`; no encoder is created.\n",
    "\n",
    "    `mode`, `label_list` and any extra arguments are accepted but unused.\n",
    "\n",
    "    Returns:\n",
    "        Always None.\n",
    "    \"\"\"\n",
    "    num_classes = 1\n",
    "    params.set_problem_info(\n",
    "        problem=problem, info_name='num_classes', info=num_classes)\n",
    "    return None"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Label handling function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "def pretrain_label_handling_fn(\n",
    "        target,\n",
    "        label_encoder=None,\n",
    "        tokenizer=None,\n",
    "        decoding_length=None,\n",
    "        *args,\n",
    "        **kwargs):\n",
    "    \"\"\"Label handling function of the pretrain problem type.\n",
    "\n",
    "    Every argument is ignored; the result is always `(None, None)`.\n",
    "    \"\"\"\n",
    "    placeholder = None\n",
    "    return placeholder, placeholder\n",
    "\n"
   ]
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 2
}
